# encoding:utf-8
from src import config
from src.log import *
from src.model.model import *
import openai
import time
import traceback
import tiktoken

user_session = dict()



# OpenAI对话模型API (可用)
# OpenAI chat-completion model wrapper (legacy openai < 1.0 SDK).
class ChatGPTModel(Model):
    """Model backed by the OpenAI ChatCompletion API (gpt-3.5-turbo)."""

    def __init__(self):
        # Credentials come from project config; api_base is optional
        # (e.g. a proxy or self-hosted compatible endpoint).
        api_key = config.MODEL_KEYS[MODEL_CHATGPT]['api_key']
        api_base = config.MODEL_KEYS[MODEL_CHATGPT].get('api_base', None)
        openai.api_key = api_key
        if api_base:
            openai.api_base = api_base

    def reply(self, query, context=None, max_retry_count=3):
        """Return the model's reply text for `query`, or None on failure.

        Transient API errors (connection, timeout, rate limit) are retried
        up to `max_retry_count` times with a fixed 4-second delay between
        attempts; any other exception aborts immediately.

        `context` is accepted for interface compatibility but unused here.
        """
        retry_count = max_retry_count
        while retry_count > 0:
            try:
                return self.__reply_text(query)
            except (openai.error.APIConnectionError,
                    openai.error.Timeout,
                    openai.error.RateLimitError):
                # One consolidated handler replaces three identical blocks.
                # NOTE: the original wrapped sleep() in a bare `except: pass`,
                # which would also swallow KeyboardInterrupt — removed.
                traceback.print_exc()
                retry_count -= 1
                time.sleep(4)
            except Exception:
                # Non-retryable failure: log the traceback and give up.
                traceback.print_exc()
                break
        return None

    def build_messages(self, prompt):
        """Build the chat message list: a fixed system prompt + the user prompt."""
        return [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt},
        ]

    def __reply_text(self, prompt):
        """Call the ChatCompletion endpoint once and return the reply text.

        Raises whatever the openai SDK raises; retry policy lives in reply().
        """
        messages = self.build_messages(prompt)
        response = openai.ChatCompletion.create(
            model="gpt-3.5-turbo-0301",  # chat model; 4096-token total context
            messages=messages,
            temperature=0.5,        # [0,1]; higher = more random sampling
            frequency_penalty=0.0,  # [-2,2]; higher discourages repeated tokens
            presence_penalty=1.0,   # [-2,2]; higher encourages novel tokens
        )
        return response.choices[0]['message']['content']
            

    